import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
import os

# Import SpikingJelly from local folder (go up two levels from models/ to project root)
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'spikingjelly'))
from spikingjelly.activation_based import neuron, surrogate, functional, layer


def get_surrogate_function(args):
    """
    Build the surrogate-gradient function selected by the configuration.

    Args:
        args: Configuration object. The following attributes are read with
            ``getattr`` and are all optional:

            - ``surrogate_function``: name of the surrogate, one of 'atan',
              's2nn', 'sigmoid', 'piecewise_quadratic', 'soft_sign',
              'piecewise_exp' or 'leaky_k_relu' (default: 'atan').
            - ``surrogate_params``: mapping of hyper-parameters. A missing
              attribute or ``None`` is treated as an empty mapping.
              Keys used: 'atan_alpha', 's2nn_alpha', 's2nn_beta',
              'sigmoid_alpha', 'alpha' (shared by piecewise_quadratic,
              soft_sign and piecewise_exp), 'leak' and 'k'.
            - ``atan_alpha``: ATan alpha used when 'atan_alpha' is not present
              in ``surrogate_params`` (default: 2.0).

    Returns:
        A newly created surrogate function instance. An unrecognised name
        prints a warning and falls back to ATan.
    """
    surrogate_name = getattr(args, 'surrogate_function', 'atan')
    # `or {}` also covers configs that explicitly set surrogate_params to None,
    # which would otherwise fail on the .get() lookups below.
    params = getattr(args, 'surrogate_params', None) or {}

    if surrogate_name == 's2nn':
        return surrogate.S2NN(
            alpha=params.get('s2nn_alpha', 4.0),
            beta=params.get('s2nn_beta', 1.0),
        )

    if surrogate_name == 'sigmoid':
        return surrogate.Sigmoid(alpha=params.get('sigmoid_alpha', 4.0))

    if surrogate_name == 'piecewise_quadratic':
        return surrogate.PiecewiseQuadratic(alpha=params.get('alpha', 2.0))

    if surrogate_name == 'soft_sign':
        return surrogate.SoftSign(alpha=params.get('alpha', 2.0))

    if surrogate_name == 'piecewise_exp':
        return surrogate.PiecewiseExp(alpha=params.get('alpha', 2.0))

    if surrogate_name == 'leaky_k_relu':
        return surrogate.LeakyKReLU(
            leak=params.get('leak', 0.1),
            k=params.get('k', 1.0),
        )

    # Everything left is either the explicit 'atan' choice or an unknown name;
    # both end up with ATan, only the unknown name is reported.
    if surrogate_name != 'atan':
        print(f"⚠️  Unknown surrogate function: {surrogate_name}, using ATan as fallback")
    alpha = params.get('atan_alpha', getattr(args, 'atan_alpha', 2.0))
    return surrogate.ATan(alpha=alpha)


class SpikingC2Linear(nn.Module):
    """Bidirectional fully connected spiking block with optional readout heads.

    Two independent paths are held side by side:

    * forward:  Linear(in_features -> out_features) followed by a LIF neuron
    * backward: Linear(out_features -> in_features) followed by a LIF neuron

    ``args.fw_bn`` / ``args.bw_bn`` choose where BatchNorm sits on each path
    (0: no BatchNorm, 1: before the linear layer, 2: after the linear layer).
    When ``args.use_spike_readout`` is truthy, each path additionally owns an
    FC -> BN readout that maps the LIF output back to float features.
    """

    def __init__(self, args, in_features, out_features, fw_bias=False, bw_bias=False):
        super().__init__()

        backend = getattr(args, 'backend', 'torch')
        self.use_readout = getattr(args, 'use_spike_readout', False)

        def batch_norm(width):
            # Built lazily so args.bn_affine is only read when a BatchNorm is needed
            return layer.BatchNorm1d(width, affine=args.bn_affine, step_mode='m')

        # Linear maps for the two directions (SpikingJelly layers, step_mode='m')
        self.forward_layer = layer.Linear(in_features, out_features, bias=fw_bias, step_mode='m')
        self.backward_layer = layer.Linear(out_features, in_features, bias=bw_bias, step_mode='m')

        # One LIF node per direction; each gets its own surrogate instance
        self.forward_lif = neuron.LIFNode(
            tau=args.tau,
            v_threshold=args.v_threshold_forward,
            surrogate_function=get_surrogate_function(args),
            step_mode='m',
            backend=backend,
        )
        self.backward_lif = neuron.LIFNode(
            tau=args.tau,
            v_threshold=args.v_threshold_backward,
            surrogate_function=get_surrogate_function(args),
            step_mode='m',
            backend=backend,
        )

        # BatchNorm placement: mode 1 normalises the layer input, mode 2 its output
        self.fw_bn = args.fw_bn
        self.bw_bn = args.bw_bn
        if self.fw_bn in (1, 2):
            self.forward_bn = batch_norm(in_features if self.fw_bn == 1 else out_features)
        if self.bw_bn in (1, 2):
            self.backward_bn = batch_norm(out_features if self.bw_bn == 1 else in_features)

        # Readout heads: LIF output -> FC -> BN, width scaled by readout_expand_factor
        if self.use_readout:
            expand_factor = getattr(args, 'readout_expand_factor', 1)
            fw_width = int(out_features * expand_factor)
            bw_width = int(in_features * expand_factor)

            self.forward_readout_fc = layer.Linear(out_features, fw_width, bias=False, step_mode='m')
            self.forward_readout_bn = batch_norm(fw_width)

            self.backward_readout_fc = layer.Linear(in_features, bw_width, bias=False, step_mode='m')
            self.backward_readout_bn = batch_norm(bw_width)

        # Optionally start both linear biases at zero
        if args.bias_init == "zero":
            for linear in (self.forward_layer, self.backward_layer):
                if linear.bias is not None:
                    linear.bias.data.zero_()

    @staticmethod
    def _bundle(primary, pre_lif, readout, return_readout, return_prelif):
        """Arrange the outputs into the tuple layout selected by the two flags."""
        if return_prelif:
            if return_readout:
                return primary, pre_lif, readout
            return primary, pre_lif
        if return_readout:
            return primary, readout
        return primary

    def get_parameters(self):
        """Collect the forward and backward parameter lists.

        The lists hold the linear layer parameters plus, when readout is
        enabled, the readout FC and BN parameters. They are also stored on the
        instance as ``forward_params`` / ``backward_params``.

        Returns:
            (forward_params, backward_params)
        """
        fw_params = list(self.forward_layer.parameters())
        bw_params = list(self.backward_layer.parameters())

        if self.use_readout:
            fw_params.extend(self.forward_readout_fc.parameters())
            fw_params.extend(self.forward_readout_bn.parameters())
            bw_params.extend(self.backward_readout_fc.parameters())
            bw_params.extend(self.backward_readout_bn.parameters())

        self.forward_params = fw_params
        self.backward_params = bw_params
        return self.forward_params, self.backward_params

    def forward(self, x, detach_grad=False, act=True, return_readout=False, return_prelif=False):
        """
        Run the forward path.

        Args:
            x: Input tensor
            detach_grad: If True, ``x`` is detached first so no gradient reaches
                         whatever produced it (keeps learning local to this layer)
            act: If False, only the linear layer is applied (no BatchNorm, no LIF)
            return_readout: Also return readout features (None when readout is
                            disabled or ``act`` is False)
            return_prelif: Also return the value fed into the LIF neuron (with
                           ``act=False`` this is the linear output itself)

        Returns:
            output                      if neither flag is set
            (output, readout)           if only return_readout is set
            (output, pre_lif)           if only return_prelif is set
            (output, pre_lif, readout)  if both flags are set
        """
        if detach_grad:
            x = x.detach()

        if not act:
            raw = self.forward_layer(x)
            return self._bundle(raw, raw, None, return_readout, return_prelif)

        # Any fw_bn value other than 0/1/2 leaves x untouched
        bn_mode = self.fw_bn
        if bn_mode == 2:
            x = self.forward_bn(self.forward_layer(x))
        elif bn_mode == 1:
            x = self.forward_layer(self.forward_bn(x))
        elif bn_mode == 0:
            x = self.forward_layer(x)

        # Copy taken before the LIF call so the returned value is a separate tensor
        pre_lif = x.clone() if return_prelif else None
        spikes = self.forward_lif(x)

        # The readout head is only evaluated when it is both requested and present
        readout = None
        if return_readout and self.use_readout:
            readout = self.forward_readout_bn(self.forward_readout_fc(spikes))

        return self._bundle(spikes, pre_lif, readout, return_readout, return_prelif)

    def reverse(self, x, detach_grad=False, act=True, return_readout=False, return_prelif=False):
        """
        Run the backward (reverse) path.

        Args:
            x: Input tensor
            detach_grad: If True, ``x`` is detached first so no gradient reaches
                         whatever produced it (keeps learning local to this layer)
            act: If False, only the linear layer is applied (no BatchNorm, no LIF)
            return_readout: Also return readout features (None when readout is
                            disabled or ``act`` is False)
            return_prelif: Also return the value fed into the LIF neuron (with
                           ``act=False`` this is the linear output itself)

        Returns:
            output                      if neither flag is set
            (output, readout)           if only return_readout is set
            (output, pre_lif)           if only return_prelif is set
            (output, pre_lif, readout)  if both flags are set
        """
        if detach_grad:
            x = x.detach()

        if not act:
            raw = self.backward_layer(x)
            return self._bundle(raw, raw, None, return_readout, return_prelif)

        # Any bw_bn value other than 0/1/2 leaves x untouched
        bn_mode = self.bw_bn
        if bn_mode == 2:
            x = self.backward_bn(self.backward_layer(x))
        elif bn_mode == 1:
            x = self.backward_layer(self.backward_bn(x))
        elif bn_mode == 0:
            x = self.backward_layer(x)

        # Copy taken before the LIF call so the returned value is a separate tensor
        pre_lif = x.clone() if return_prelif else None
        spikes = self.backward_lif(x)

        # The readout head is only evaluated when it is both requested and present
        readout = None
        if return_readout and self.use_readout:
            readout = self.backward_readout_bn(self.backward_readout_fc(spikes))

        return self._bundle(spikes, pre_lif, readout, return_readout, return_prelif)


class SpikingC2Conv(nn.Module):
    """Bidirectional convolutional spiking block with optional readout heads.

    * forward:  Conv2d(in_channels -> out_channels, stride 2) followed by a LIF neuron
    * backward: nearest-neighbour 2x upsample, then
                Conv2d(out_channels -> in_channels, stride 1) followed by a LIF neuron

    ``args.fw_bn`` / ``args.bw_bn`` choose where BatchNorm sits on each path
    (0: no BatchNorm, 1: before the convolution, 2: after the convolution).
    When ``args.use_spike_readout`` is truthy, each path additionally owns a
    1x1 Conv -> BN readout that maps the LIF output back to float features.

    NOTE: ``fw_bias`` and ``bw_bias`` are accepted but not used; both
    convolutions are always created with ``bias=True``.
    """

    def __init__(self, args, in_channels, out_channels, kernel, fw_bias=False, bw_bias=False):
        super().__init__()

        backend = getattr(args, 'backend', 'torch')
        self.use_readout = getattr(args, 'use_spike_readout', False)

        def batch_norm(channels):
            # Built lazily so args.bn_affine is only read when a BatchNorm is needed
            return layer.BatchNorm2d(channels, affine=args.bn_affine, step_mode='m')

        same_pad = kernel // 2

        # Strided conv going forward; upsample + stride-1 conv coming back
        self.forward_layer = layer.Conv2d(in_channels, out_channels, kernel, stride=2,
                                          padding=same_pad, bias=True, step_mode='m')
        self.backward_layer = layer.Conv2d(out_channels, in_channels, kernel, stride=1,
                                           padding=same_pad, bias=True, step_mode='m')
        self.upsample = layer.Upsample(scale_factor=2, mode='nearest', step_mode='m')

        # BatchNorm placement: mode 1 normalises the conv input, mode 2 its output
        self.fw_bn = args.fw_bn
        self.bw_bn = args.bw_bn
        if self.fw_bn in (1, 2):
            self.forward_bn = batch_norm(in_channels if self.fw_bn == 1 else out_channels)
        if self.bw_bn in (1, 2):
            self.backward_bn = batch_norm(out_channels if self.bw_bn == 1 else in_channels)

        # One LIF node per direction; each gets its own surrogate instance
        self.forward_lif = neuron.LIFNode(
            tau=args.tau,
            v_threshold=args.v_threshold_forward,
            surrogate_function=get_surrogate_function(args),
            step_mode='m',
            backend=backend,
        )
        self.backward_lif = neuron.LIFNode(
            tau=args.tau,
            v_threshold=args.v_threshold_backward,
            surrogate_function=get_surrogate_function(args),
            step_mode='m',
            backend=backend,
        )

        # Readout heads: LIF output -> 1x1 conv -> BN; channels scaled by
        # readout_expand_factor, spatial size untouched by the 1x1 kernel
        if self.use_readout:
            expand_factor = getattr(args, 'readout_expand_factor', 1)
            fw_channels = int(out_channels * expand_factor)
            bw_channels = int(in_channels * expand_factor)

            self.forward_readout_conv = layer.Conv2d(out_channels, fw_channels, 1, bias=False, step_mode='m')
            self.forward_readout_bn = batch_norm(fw_channels)

            self.backward_readout_conv = layer.Conv2d(in_channels, bw_channels, 1, bias=False, step_mode='m')
            self.backward_readout_bn = batch_norm(bw_channels)

        # Optionally start both conv biases at zero
        if args.bias_init == "zero":
            for conv in (self.forward_layer, self.backward_layer):
                if conv.bias is not None:
                    conv.bias.data.zero_()

    @staticmethod
    def _bundle(primary, pre_lif, readout, return_readout, return_prelif):
        """Arrange the outputs into the tuple layout selected by the two flags."""
        if return_prelif:
            if return_readout:
                return primary, pre_lif, readout
            return primary, pre_lif
        if return_readout:
            return primary, readout
        return primary

    def get_parameters(self):
        """Collect the forward and backward parameter lists.

        The lists hold the convolution parameters plus, when readout is
        enabled, the readout conv and BN parameters. They are also stored on
        the instance as ``forward_params`` / ``backward_params``.

        Returns:
            (forward_params, backward_params)
        """
        fw_params = list(self.forward_layer.parameters())
        bw_params = list(self.backward_layer.parameters())

        if self.use_readout:
            fw_params.extend(self.forward_readout_conv.parameters())
            fw_params.extend(self.forward_readout_bn.parameters())
            bw_params.extend(self.backward_readout_conv.parameters())
            bw_params.extend(self.backward_readout_bn.parameters())

        self.forward_params = fw_params
        self.backward_params = bw_params
        return self.forward_params, self.backward_params

    def forward(self, x, detach_grad=False, act=True, return_readout=False, return_prelif=False):
        """
        Run the forward path.

        Args:
            x: Input tensor [T, N, C, H, W]
            detach_grad: If True, ``x`` is detached first so no gradient reaches
                         whatever produced it (keeps learning local to this layer)
            act: If False, only the convolution is applied (no BatchNorm, no LIF)
            return_readout: Also return readout features (None when readout is
                            disabled or ``act`` is False)
            return_prelif: Also return the value fed into the LIF neuron (with
                           ``act=False`` this is the convolution output itself)

        Returns:
            output                      if neither flag is set
            (output, readout)           if only return_readout is set
            (output, pre_lif)           if only return_prelif is set
            (output, pre_lif, readout)  if both flags are set
        """
        if detach_grad:
            x = x.detach()

        if not act:
            raw = self.forward_layer(x)
            return self._bundle(raw, raw, None, return_readout, return_prelif)

        # Any fw_bn value other than 0/1/2 leaves x untouched
        bn_mode = self.fw_bn
        if bn_mode == 2:
            x = self.forward_bn(self.forward_layer(x))
        elif bn_mode == 1:
            x = self.forward_layer(self.forward_bn(x))
        elif bn_mode == 0:
            x = self.forward_layer(x)

        # Copy taken before the LIF call so the returned value is a separate tensor
        pre_lif = x.clone() if return_prelif else None
        spikes = self.forward_lif(x)

        # The readout head is only evaluated when it is both requested and present
        readout = None
        if return_readout and self.use_readout:
            readout = self.forward_readout_bn(self.forward_readout_conv(spikes))

        return self._bundle(spikes, pre_lif, readout, return_readout, return_prelif)

    def reverse(self, x, detach_grad=False, act=True, return_readout=False, return_prelif=False):
        """
        Run the backward (reverse) path.

        Args:
            x: Input tensor [T, N, C, H, W]
            detach_grad: If True, ``x`` is detached first so no gradient reaches
                         whatever produced it (keeps learning local to this layer)
            act: If False, only upsampling and the convolution are applied
                 (no BatchNorm, no LIF)
            return_readout: Also return readout features (None when readout is
                            disabled or ``act`` is False)
            return_prelif: Also return the value fed into the LIF neuron (with
                           ``act=False`` this is the convolution output itself)

        Returns:
            output                      if neither flag is set
            (output, readout)           if only return_readout is set
            (output, pre_lif)           if only return_prelif is set
            (output, pre_lif, readout)  if both flags are set
        """
        if detach_grad:
            x = x.detach()

        # Upsampling always comes first, including on the act=False path
        x = self.upsample(x)

        if not act:
            raw = self.backward_layer(x)
            return self._bundle(raw, raw, None, return_readout, return_prelif)

        # Any bw_bn value other than 0/1/2 leaves x untouched
        bn_mode = self.bw_bn
        if bn_mode == 2:
            x = self.backward_bn(self.backward_layer(x))
        elif bn_mode == 1:
            x = self.backward_layer(self.backward_bn(x))
        elif bn_mode == 0:
            x = self.backward_layer(x)

        # Copy taken before the LIF call so the returned value is a separate tensor
        pre_lif = x.clone() if return_prelif else None
        spikes = self.backward_lif(x)

        # The readout head is only evaluated when it is both requested and present
        readout = None
        if return_readout and self.use_readout:
            readout = self.backward_readout_bn(self.backward_readout_conv(spikes))

        return self._bundle(spikes, pre_lif, readout, return_readout, return_prelif)


# SpikingC2View and SpikingC2Pool have no LIF neurons and therefore no readout
# branches; they only mirror the (output, prelif, readout) return interface.
# SpikingC2ConvStride1, in contrast, does contain LIF neurons.

class SpikingC2View(nn.Module):
    """Parameter-free reshape layer between two per-sample shapes.

    The first two dimensions of the input (documented in this file as T and N)
    are kept; the remaining dimensions are reshaped to ``output_shape`` in
    ``forward`` and back to ``input_shape`` in ``reverse``.

    Args:
        input_shape: Per-sample shape produced by ``reverse`` (int or sequence)
        output_shape: Per-sample shape produced by ``forward`` (int or sequence)
    """

    def __init__(self, input_shape, output_shape):
        super().__init__()
        # A bare int is wrapped so both shapes can be concatenated with [T, N]
        self.input_shape = [input_shape] if isinstance(input_shape, int) else input_shape
        self.output_shape = [output_shape] if isinstance(output_shape, int) else output_shape

    @staticmethod
    def _bundle(result, return_readout, return_prelif):
        """Arrange the output into the tuple layout selected by the two flags.

        There is no LIF neuron here, so the pre-LIF slot repeats the output and
        the readout slot is always None.
        """
        if return_prelif:
            if return_readout:
                return result, result, None
            return result, result
        if return_readout:
            return result, None
        return result

    @staticmethod
    def _reshape_keep_leading(x, per_sample_shape):
        """Reshape everything after the first two dimensions of ``x``."""
        T = x.shape[0]
        N = x.shape[1]
        # reshape() returns a view whenever view() would have worked and falls
        # back to a copy for non-contiguous inputs, where view() raises.
        return x.reshape([T, N] + list(per_sample_shape))

    def forward(self, x, detach_grad=False, return_readout=False, return_prelif=False):
        """
        Reshape ``x`` from [T, N, *input_shape] to [T, N, *output_shape].

        Args:
            x: Input tensor with at least two dimensions
            detach_grad: Accepted for interface parity with the other layers; not used
            return_readout: Also return a readout slot (always None)
            return_prelif: Also return a pre-LIF slot (same tensor as the output)

        Returns:
            output, (output, readout), (output, pre_lif) or
            (output, pre_lif, readout) depending on the flags
        """
        result = self._reshape_keep_leading(x, self.output_shape)
        return self._bundle(result, return_readout, return_prelif)

    def reverse(self, x, detach_grad=False, return_readout=False, return_prelif=False):
        """
        Reshape ``x`` from [T, N, *output_shape] back to [T, N, *input_shape].

        Args:
            x: Input tensor with at least two dimensions
            detach_grad: Accepted for interface parity with the other layers; not used
            return_readout: Also return a readout slot (always None)
            return_prelif: Also return a pre-LIF slot (same tensor as the output)

        Returns:
            output, (output, readout), (output, pre_lif) or
            (output, pre_lif, readout) depending on the flags
        """
        result = self._reshape_keep_leading(x, self.input_shape)
        return self._bundle(result, return_readout, return_prelif)

    def get_parameters(self):
        """Return empty forward/backward parameter lists (nothing is trainable)."""
        return [], []


class SpikingC2Pool(nn.Module):
    """Parameter-free pooling layer: 2x2 max-pool forward, 2x nearest upsample in reverse.

    The constructor arguments ``args``, ``kernel_size``, ``stride`` and
    ``padding`` are accepted but not used; the pooling window and stride are
    fixed at 2.
    """

    def __init__(self, args, kernel_size, stride, padding):
        super().__init__()
        # Fixed (2, 2) pooling regardless of the arguments passed in
        self.max_pool = layer.MaxPool2d(2, 2, step_mode='m')
        self.upsample = layer.Upsample(scale_factor=2, mode='nearest', step_mode='m')

    @staticmethod
    def _bundle(result, return_readout, return_prelif):
        """Arrange the output into the tuple layout selected by the two flags.

        There is no LIF neuron here, so the pre-LIF slot repeats the output and
        the readout slot is always None.
        """
        if return_prelif:
            if return_readout:
                return result, result, None
            return result, result
        if return_readout:
            return result, None
        return result

    def forward(self, x, detach_grad=False, return_readout=False, return_prelif=False):
        """
        Max-pool ``x``.

        Args:
            x: Input tensor
            detach_grad: Accepted for interface parity with the other layers; not used
            return_readout: Also return a readout slot (always None)
            return_prelif: Also return a pre-LIF slot (same tensor as the output)

        Returns:
            output, (output, readout), (output, pre_lif) or
            (output, pre_lif, readout) depending on the flags
        """
        pooled = self.max_pool(x)
        return self._bundle(pooled, return_readout, return_prelif)

    def reverse(self, x, detach_grad=False, return_readout=False, return_prelif=False):
        """
        Upsample ``x`` by a factor of 2 (nearest neighbour).

        The reverse path does not reuse anything from ``forward``; it simply
        upsamples its own input.

        Args:
            x: Input tensor
            detach_grad: Accepted for interface parity with the other layers; not used
            return_readout: Also return a readout slot (always None)
            return_prelif: Also return a pre-LIF slot (same tensor as the output)

        Returns:
            output, (output, readout), (output, pre_lif) or
            (output, pre_lif, readout) depending on the flags
        """
        enlarged = self.upsample(x)
        return self._bundle(enlarged, return_readout, return_prelif)

    def get_parameters(self):
        """Return empty forward/backward parameter lists (nothing is trainable)."""
        return [], []


class SpikingC2ConvStride1(nn.Module):
    """
    Bidirectional spiking conv block using stride-1 convolutions in both directions.

    Each direction runs Conv2d -> BatchNorm (optional, placement set by
    args.fw_bn / args.bw_bn) -> LIF. When args.use_spike_readout is truthy, a
    Conv1x1 + BatchNorm readout branch is attached to each direction's spikes.
    """

    def __init__(self, args, in_channels, out_channels, kernel, fw_bias=False, bw_bias=False):
        """
        Build the forward/backward convs, optional BN layers, LIF neurons and readouts.

        Args:
            args: Configuration object. Reads fw_bn, bw_bn, tau, v_threshold_forward,
                  v_threshold_backward and bias_init; bn_affine when any BN layer is built;
                  and, with defaults, backend ('torch'), use_spike_readout (False) and
                  readout_expand_factor (1)
            in_channels: Channels on the input side of the forward path
            out_channels: Channels on the output side of the forward path
            kernel: Convolution kernel size (padding is kernel // 2)
            fw_bias: Whether the forward conv has a bias
            bw_bias: Whether the backward conv has a bias
        """
        super().__init__()

        lif_backend = getattr(args, 'backend', 'torch')
        self.use_readout = getattr(args, 'use_spike_readout', False)

        # Stride-1 convolutions in both directions (following the original layer).
        half_kernel = kernel // 2
        self.forward_layer = layer.Conv2d(
            in_channels, out_channels, kernel,
            stride=1, padding=half_kernel, bias=fw_bias, step_mode='m',
        )
        self.backward_layer = layer.Conv2d(
            out_channels, in_channels, kernel,
            stride=1, padding=half_kernel, bias=bw_bias, step_mode='m',
        )

        # BN placement per direction: 0 = none, 1 = before the conv, 2 = after the conv.
        self.fw_bn = args.fw_bn
        self.bw_bn = args.bw_bn

        # BN width follows whichever side of the conv the BN sits on.
        if self.fw_bn in (1, 2):
            fw_bn_channels = in_channels if self.fw_bn == 1 else out_channels
            self.forward_bn = layer.BatchNorm2d(fw_bn_channels, affine=args.bn_affine, step_mode='m')
        if self.bw_bn in (1, 2):
            bw_bn_channels = out_channels if self.bw_bn == 1 else in_channels
            self.backward_bn = layer.BatchNorm2d(bw_bn_channels, affine=args.bn_affine, step_mode='m')

        def build_lif(threshold_attr):
            # Each neuron gets its own surrogate-function instance.
            surrogate_fn = get_surrogate_function(args)
            return neuron.LIFNode(
                tau=args.tau,
                v_threshold=getattr(args, threshold_attr),
                surrogate_function=surrogate_fn,
                step_mode='m',
                backend=lif_backend,
            )

        self.forward_lif = build_lif('v_threshold_forward')
        self.backward_lif = build_lif('v_threshold_backward')

        # Readout branches: spikes -> Conv1x1 -> BN -> float features (for BSD loss).
        if self.use_readout:
            expand = getattr(args, 'readout_expand_factor', 1)
            fw_readout_channels = int(out_channels * expand)
            bw_readout_channels = int(in_channels * expand)

            self.forward_readout_conv = layer.Conv2d(
                out_channels, fw_readout_channels, 1, bias=False, step_mode='m')
            self.forward_readout_bn = layer.BatchNorm2d(
                fw_readout_channels, affine=args.bn_affine, step_mode='m')

            self.backward_readout_conv = layer.Conv2d(
                in_channels, bw_readout_channels, 1, bias=False, step_mode='m')
            self.backward_readout_bn = layer.BatchNorm2d(
                bw_readout_channels, affine=args.bn_affine, step_mode='m')

        # Optional zero-initialisation of whichever conv biases exist.
        if args.bias_init == "zero":
            for conv in (self.forward_layer, self.backward_layer):
                if conv.bias is not None:
                    conv.bias.data.zero_()

    def get_parameters(self):
        """
        Collect the parameters of each direction.

        Only the conv layers and (if enabled) the readout conv/BN layers are
        listed; the main-path BN and LIF modules are not included.

        Returns:
            (forward_params, backward_params), also stored on self under the same names
        """
        fw_params = list(self.forward_layer.parameters())
        bw_params = list(self.backward_layer.parameters())

        if self.use_readout:
            fw_params.extend(self.forward_readout_conv.parameters())
            fw_params.extend(self.forward_readout_bn.parameters())
            bw_params.extend(self.backward_readout_conv.parameters())
            bw_params.extend(self.backward_readout_bn.parameters())

        self.forward_params = fw_params
        self.backward_params = bw_params
        return self.forward_params, self.backward_params

    @staticmethod
    def _pack_outputs(primary, pre_lif, readout, return_readout, return_prelif):
        """Arrange the outputs into the tuple shape selected by the two flags."""
        if return_prelif:
            if return_readout:
                return primary, pre_lif, readout
            return primary, pre_lif
        if return_readout:
            return primary, readout
        return primary

    def forward(self, x, detach_grad=False, act=True, return_readout=False, return_prelif=False):
        """
        Forward pass: Conv (+ optional BN) -> LIF, with optional extra outputs.

        Args:
            x: Input tensor [T, N, C, H, W]
            detach_grad: If True, detach x first so no gradient reaches earlier layers
                         (keeps learning local, as needed for BSD/local learning)
            act: If False, return the bare conv output (no BN, no LIF)
            return_readout: Whether to also return readout features (for BSD loss)
            return_prelif: Whether to also return the pre-LIF values

        Returns:
            Neither flag: spike features
            return_readout only: (spike_features, readout_features)
            return_prelif only: (spike_features, pre_lif_features)
            Both: (spike_features, pre_lif_features, readout_features)
            readout_features is None when the readout branch is disabled or act=False;
            with act=False the conv output fills both the spike and pre-LIF slots.
        """
        if detach_grad:
            x = x.detach()

        if not act:
            linear_output = self.forward_layer(x)
            return self._pack_outputs(linear_output, linear_output, None,
                                      return_readout, return_prelif)

        # BN before the conv (mode 1), the conv itself, BN after the conv (mode 2).
        # Any other mode value leaves x untouched, conv included.
        bn_mode = self.fw_bn
        if bn_mode == 1:
            x = self.forward_bn(x)
        if bn_mode in (0, 1, 2):
            x = self.forward_layer(x)
        if bn_mode == 2:
            x = self.forward_bn(x)

        pre_lif = x.clone() if return_prelif else None
        spike_features = self.forward_lif(x)

        readout_features = None
        if return_readout and self.use_readout:
            readout_features = self.forward_readout_bn(self.forward_readout_conv(spike_features))

        return self._pack_outputs(spike_features, pre_lif, readout_features,
                                  return_readout, return_prelif)

    def reverse(self, x, detach_grad=False, act=True, return_readout=False, return_prelif=False):
        """
        Reverse pass: backward Conv (+ optional BN) -> LIF, with optional extra outputs.

        Args:
            x: Input tensor [T, N, C, H, W]
            detach_grad: If True, detach x first so no gradient reaches earlier layers
                         (keeps learning local, as needed for BSD/local learning)
            act: If False, return the bare conv output (no BN, no LIF)
            return_readout: Whether to also return readout features (for BSD loss)
            return_prelif: Whether to also return the pre-LIF values

        Returns:
            Neither flag: spike features
            return_readout only: (spike_features, readout_features)
            return_prelif only: (spike_features, pre_lif_features)
            Both: (spike_features, pre_lif_features, readout_features)
            readout_features is None when the readout branch is disabled or act=False;
            with act=False the conv output fills both the spike and pre-LIF slots.
        """
        if detach_grad:
            x = x.detach()

        if not act:
            linear_output = self.backward_layer(x)
            return self._pack_outputs(linear_output, linear_output, None,
                                      return_readout, return_prelif)

        # BN before the conv (mode 1), the conv itself, BN after the conv (mode 2).
        # Any other mode value leaves x untouched, conv included.
        bn_mode = self.bw_bn
        if bn_mode == 1:
            x = self.backward_bn(x)
        if bn_mode in (0, 1, 2):
            x = self.backward_layer(x)
        if bn_mode == 2:
            x = self.backward_bn(x)

        pre_lif = x.clone() if return_prelif else None
        spike_features = self.backward_lif(x)

        readout_features = None
        if return_readout and self.use_readout:
            readout_features = self.backward_readout_bn(self.backward_readout_conv(spike_features))

        return self._pack_outputs(spike_features, pre_lif, readout_features,
                                  return_readout, return_prelif)


class SpikingC2ConvTrans(nn.Module):
    """
    Bidirectional spiking conv block that grows the spatial size in forward.

    - Forward: nearest 2x Upsample -> Conv2d(stride=1) -> BN(optional) -> LIF
    - Reverse: Conv2d(stride=2) -> BN(optional) -> LIF
    When args.use_spike_readout is truthy, a Conv1x1 + BatchNorm readout branch
    is attached to each direction's spikes.
    """

    def __init__(self, args, in_channels, out_channels, kernel, fw_bias=False, bw_bias=False):
        """
        Build the convs, the forward upsampler, optional BN layers, LIF neurons and readouts.

        Args:
            args: Configuration object. Reads fw_bn, bw_bn, tau, v_threshold_forward,
                  v_threshold_backward and bias_init; bn_affine when any BN layer is built;
                  and, with defaults, backend ("torch"), use_spike_readout (False) and
                  readout_expand_factor (1)
            in_channels: Channels on the input side of the forward path
            out_channels: Channels on the output side of the forward path
            kernel: Convolution kernel size (padding is kernel // 2)
            fw_bias: Whether the forward conv has a bias
            bw_bias: Whether the backward conv has a bias
        """
        super().__init__()

        lif_backend = getattr(args, "backend", "torch")
        self.use_readout = getattr(args, "use_spike_readout", False)

        # Forward conv keeps resolution (stride 1); backward conv halves it (stride 2).
        half_kernel = kernel // 2
        self.forward_layer = layer.Conv2d(
            in_channels, out_channels, kernel,
            stride=1, padding=half_kernel, bias=fw_bias, step_mode="m",
        )
        self.backward_layer = layer.Conv2d(
            out_channels, in_channels, kernel,
            stride=2, padding=half_kernel, bias=bw_bias, step_mode="m",
        )

        # Applied to the input at the start of every forward pass.
        self.upsample = layer.Upsample(scale_factor=2, mode="nearest", step_mode="m")

        # BN placement per direction: 0 = none, 1 = before the conv, 2 = after the conv.
        self.fw_bn = args.fw_bn
        self.bw_bn = args.bw_bn

        # BN width follows whichever side of the conv the BN sits on.
        if self.fw_bn in (1, 2):
            fw_bn_channels = in_channels if self.fw_bn == 1 else out_channels
            self.forward_bn = layer.BatchNorm2d(fw_bn_channels, affine=args.bn_affine, step_mode="m")
        if self.bw_bn in (1, 2):
            bw_bn_channels = out_channels if self.bw_bn == 1 else in_channels
            self.backward_bn = layer.BatchNorm2d(bw_bn_channels, affine=args.bn_affine, step_mode="m")

        def build_lif(threshold_attr):
            # Each neuron gets its own surrogate-function instance.
            surrogate_fn = get_surrogate_function(args)
            return neuron.LIFNode(
                tau=args.tau,
                v_threshold=getattr(args, threshold_attr),
                surrogate_function=surrogate_fn,
                step_mode="m",
                backend=lif_backend,
            )

        self.forward_lif = build_lif("v_threshold_forward")
        self.backward_lif = build_lif("v_threshold_backward")

        # Readout branches: spikes -> Conv1x1 -> BN -> float features (for BSD loss).
        if self.use_readout:
            expand = getattr(args, "readout_expand_factor", 1)
            fw_readout_channels = int(out_channels * expand)
            bw_readout_channels = int(in_channels * expand)

            self.forward_readout_conv = layer.Conv2d(
                out_channels, fw_readout_channels, 1, bias=False, step_mode="m")
            self.forward_readout_bn = layer.BatchNorm2d(
                fw_readout_channels, affine=args.bn_affine, step_mode="m")

            self.backward_readout_conv = layer.Conv2d(
                in_channels, bw_readout_channels, 1, bias=False, step_mode="m")
            self.backward_readout_bn = layer.BatchNorm2d(
                bw_readout_channels, affine=args.bn_affine, step_mode="m")

        # Optional zero-initialisation of whichever conv biases exist.
        if args.bias_init == "zero":
            for conv in (self.forward_layer, self.backward_layer):
                if conv.bias is not None:
                    conv.bias.data.zero_()

    def get_parameters(self):
        """
        Collect the parameters of each direction.

        Only the conv layers and (if enabled) the readout conv/BN layers are
        listed; the main-path BN and LIF modules are not included.

        Returns:
            (forward_params, backward_params), also stored on self under the same names
        """
        fw_params = list(self.forward_layer.parameters())
        bw_params = list(self.backward_layer.parameters())

        if self.use_readout:
            fw_params.extend(self.forward_readout_conv.parameters())
            fw_params.extend(self.forward_readout_bn.parameters())
            bw_params.extend(self.backward_readout_conv.parameters())
            bw_params.extend(self.backward_readout_bn.parameters())

        self.forward_params = fw_params
        self.backward_params = bw_params
        return self.forward_params, self.backward_params

    @staticmethod
    def _pack_outputs(primary, pre_lif, readout, return_readout, return_prelif):
        """Arrange the outputs into the tuple shape selected by the two flags."""
        if return_prelif:
            if return_readout:
                return primary, pre_lif, readout
            return primary, pre_lif
        if return_readout:
            return primary, readout
        return primary

    def forward(self, x, detach_grad=False, act=True, return_readout=False, return_prelif=False):
        """
        Forward pass: Upsample -> Conv (+ optional BN) -> LIF, with optional extra outputs.

        Args:
            x: Input tensor [T, N, C, H, W]
            detach_grad: If True, detach x first so no gradient reaches earlier layers
                         (keeps learning local, as needed for BSD/local learning)
            act: If False, return the conv output of the upsampled input (no BN, no LIF)
            return_readout: Whether to also return readout features (for BSD loss)
            return_prelif: Whether to also return the pre-LIF values

        Returns:
            Neither flag: spike features
            return_readout only: (spike_features, readout_features)
            return_prelif only: (spike_features, pre_lif_features)
            Both: (spike_features, pre_lif_features, readout_features)
            readout_features is None when the readout branch is disabled or act=False;
            with act=False the conv output fills both the spike and pre-LIF slots.
        """
        if detach_grad:
            x = x.detach()

        # Upsampling always comes first, even on the act=False path.
        x = self.upsample(x)

        if not act:
            linear_output = self.forward_layer(x)
            return self._pack_outputs(linear_output, linear_output, None,
                                      return_readout, return_prelif)

        # BN before the conv (mode 1), the conv itself, BN after the conv (mode 2).
        # Any other mode value skips the conv as well.
        bn_mode = self.fw_bn
        if bn_mode == 1:
            x = self.forward_bn(x)
        if bn_mode in (0, 1, 2):
            x = self.forward_layer(x)
        if bn_mode == 2:
            x = self.forward_bn(x)

        pre_lif = x.clone() if return_prelif else None
        spike_features = self.forward_lif(x)

        readout_features = None
        if return_readout and self.use_readout:
            readout_features = self.forward_readout_bn(self.forward_readout_conv(spike_features))

        return self._pack_outputs(spike_features, pre_lif, readout_features,
                                  return_readout, return_prelif)

    def reverse(self, x, detach_grad=False, act=True, return_readout=False, return_prelif=False):
        """
        Reverse pass: stride-2 Conv (+ optional BN) -> LIF, with optional extra outputs.

        Args:
            x: Input tensor [T, N, C, H, W]
            detach_grad: If True, detach x first so no gradient reaches earlier layers
                         (keeps learning local, as needed for BSD/local learning)
            act: If False, return the bare stride-2 conv output (no BN, no LIF)
            return_readout: Whether to also return readout features (for BSD loss)
            return_prelif: Whether to also return the pre-LIF values

        Returns:
            Neither flag: spike features
            return_readout only: (spike_features, readout_features)
            return_prelif only: (spike_features, pre_lif_features)
            Both: (spike_features, pre_lif_features, readout_features)
            readout_features is None when the readout branch is disabled or act=False;
            with act=False the conv output fills both the spike and pre-LIF slots.
        """
        if detach_grad:
            x = x.detach()

        if not act:
            linear_output = self.backward_layer(x)
            return self._pack_outputs(linear_output, linear_output, None,
                                      return_readout, return_prelif)

        # BN before the conv (mode 1), the conv itself, BN after the conv (mode 2).
        # Any other mode value leaves x untouched, conv included.
        bn_mode = self.bw_bn
        if bn_mode == 1:
            x = self.backward_bn(x)
        if bn_mode in (0, 1, 2):
            x = self.backward_layer(x)
        if bn_mode == 2:
            x = self.backward_bn(x)

        pre_lif = x.clone() if return_prelif else None
        spike_features = self.backward_lif(x)

        readout_features = None
        if return_readout and self.use_readout:
            readout_features = self.backward_readout_bn(self.backward_readout_conv(spike_features))

        return self._pack_outputs(spike_features, pre_lif, readout_features,
                                  return_readout, return_prelif)


class SpikingC2ConvWithTranspose(nn.Module):
    """
    Spiking CNN layer pairing a stride-2 convolution with a stride-2 transposed convolution.

    - Forward: Conv2d(stride=2) -> BN(optional) -> LIF (dimension reduction)
    - Reverse: ConvTranspose2d(stride=2) -> BN(optional) -> LIF (learnable upsampling)
    - No separate pooling layer needed

    NOTE: fw_bias / bw_bias are accepted but not used; both layers are always
    built with bias=True.
    """

    def __init__(self, args, in_channels, out_channels, kernel, fw_bias=False, bw_bias=False):
        """
        Build the conv / transposed-conv pair, optional BN layers, LIF neurons and readouts.

        Args:
            args: Configuration object. Reads fw_bn, bw_bn, tau, v_threshold_forward,
                  v_threshold_backward and bias_init; bn_affine when any BN layer is built;
                  and, with defaults, backend ('torch'), use_spike_readout (False) and
                  readout_expand_factor (1)
            in_channels: Channels on the input side of the forward path
            out_channels: Channels on the output side of the forward path
            kernel: Convolution kernel size (padding is kernel // 2)
            fw_bias: Unused (the forward conv always has a bias)
            bw_bias: Unused (the backward transposed conv always has a bias)
        """
        super().__init__()

        lif_backend = getattr(args, 'backend', 'torch')
        self.use_readout = getattr(args, 'use_spike_readout', False)

        half_kernel = kernel // 2

        # Forward: strided conv performs the downsampling (replaces pooling).
        self.forward_layer = layer.Conv2d(
            in_channels, out_channels, kernel,
            stride=2, padding=half_kernel, bias=True, step_mode='m',
        )

        # Reverse: transposed conv performs learnable upsampling (spikingjelly's
        # wrapper is used because it supports step_mode).
        self.backward_layer = layer.ConvTranspose2d(
            out_channels, in_channels, kernel,
            stride=2, padding=half_kernel, output_padding=1, bias=True, step_mode='m',
        )

        # BN placement per direction: 0 = none, 1 = before the conv, 2 = after the conv.
        self.fw_bn = args.fw_bn
        self.bw_bn = args.bw_bn

        # BN width follows whichever side of the conv the BN sits on.
        if self.fw_bn in (1, 2):
            fw_bn_channels = in_channels if self.fw_bn == 1 else out_channels
            self.forward_bn = layer.BatchNorm2d(fw_bn_channels, affine=args.bn_affine, step_mode='m')
        if self.bw_bn in (1, 2):
            bw_bn_channels = out_channels if self.bw_bn == 1 else in_channels
            self.backward_bn = layer.BatchNorm2d(bw_bn_channels, affine=args.bn_affine, step_mode='m')

        def build_lif(threshold_attr):
            # Each neuron gets its own surrogate-function instance.
            surrogate_fn = get_surrogate_function(args)
            return neuron.LIFNode(
                tau=args.tau,
                v_threshold=getattr(args, threshold_attr),
                surrogate_function=surrogate_fn,
                step_mode='m',
                backend=lif_backend,
            )

        self.forward_lif = build_lif('v_threshold_forward')
        self.backward_lif = build_lif('v_threshold_backward')

        # Readout branches: spikes -> Conv1x1 -> BN -> float features.
        if self.use_readout:
            expand = getattr(args, 'readout_expand_factor', 1)
            fw_readout_channels = int(out_channels * expand)
            bw_readout_channels = int(in_channels * expand)

            self.forward_readout_conv = layer.Conv2d(
                out_channels, fw_readout_channels, 1, bias=False, step_mode='m')
            self.forward_readout_bn = layer.BatchNorm2d(
                fw_readout_channels, affine=args.bn_affine, step_mode='m')

            self.backward_readout_conv = layer.Conv2d(
                in_channels, bw_readout_channels, 1, bias=False, step_mode='m')
            self.backward_readout_bn = layer.BatchNorm2d(
                bw_readout_channels, affine=args.bn_affine, step_mode='m')

        # Optional zero-initialisation of the layer biases.
        if args.bias_init == "zero":
            for conv in (self.forward_layer, self.backward_layer):
                if conv.bias is not None:
                    conv.bias.data.zero_()

    def get_parameters(self):
        """
        Collect the parameters of each direction.

        Only the conv / transposed-conv layers and (if enabled) the readout
        conv/BN layers are listed; the main-path BN and LIF modules are not included.

        Returns:
            (forward_params, backward_params), also stored on self under the same names
        """
        fw_params = list(self.forward_layer.parameters())
        bw_params = list(self.backward_layer.parameters())

        if self.use_readout:
            fw_params.extend(self.forward_readout_conv.parameters())
            fw_params.extend(self.forward_readout_bn.parameters())
            bw_params.extend(self.backward_readout_conv.parameters())
            bw_params.extend(self.backward_readout_bn.parameters())

        self.forward_params = fw_params
        self.backward_params = bw_params
        return self.forward_params, self.backward_params

    @staticmethod
    def _pack_outputs(primary, pre_lif, readout, return_readout, return_prelif):
        """Arrange the outputs into the tuple shape selected by the two flags."""
        if return_prelif:
            if return_readout:
                return primary, pre_lif, readout
            return primary, pre_lif
        if return_readout:
            return primary, readout
        return primary

    def forward(self, x, detach_grad=False, act=True, return_readout=False, return_prelif=False):
        """
        Forward pass: Conv(stride=2) -> BN(optional) -> LIF

        Args:
            x: Input tensor [T, N, C, H, W]
            detach_grad: If True, detach x first so no gradient reaches earlier layers
                         (keeps learning local, as needed for BSD/local learning)
            act: If False, stop after Conv + BN and return that value (no LIF)
            return_readout: Whether to also return readout features (for BSD loss)
            return_prelif: Whether to also return the pre-LIF values

        Returns:
            Neither flag: spike features
            return_readout only: (spike_features, readout_features)
            return_prelif only: (spike_features, pre_lif_features)
            Both: (spike_features, pre_lif_features, readout_features)
            readout_features is None when the readout branch is disabled or act=False;
            with act=False the Conv + BN output fills both the spike and pre-LIF slots.
        """
        if detach_grad:
            x = x.detach()

        # BN before the conv (mode 1), the conv itself, BN after the conv (mode 2).
        # Unlike the act=False path of the plain conv layers, BN is applied here
        # before the early return. Any other mode value leaves x untouched.
        bn_mode = self.fw_bn
        if bn_mode == 1:
            x = self.forward_bn(x)
        if bn_mode in (0, 1, 2):
            x = self.forward_layer(x)
        if bn_mode == 2:
            x = self.forward_bn(x)

        if not act:
            return self._pack_outputs(x, x, None, return_readout, return_prelif)

        pre_lif = x.clone() if return_prelif else None
        spike_features = self.forward_lif(x)

        readout_features = None
        if return_readout and self.use_readout:
            readout_features = self.forward_readout_bn(self.forward_readout_conv(spike_features))

        return self._pack_outputs(spike_features, pre_lif, readout_features,
                                  return_readout, return_prelif)

    def reverse(self, x, detach_grad=False, act=True, return_readout=False, return_prelif=False):
        """
        Reverse pass: ConvTranspose2d(stride=2) -> BN(optional) -> LIF

        Args:
            x: Input tensor [T, N, C, H, W]
            detach_grad: If True, detach x first so no gradient reaches earlier layers
                         (keeps learning local, as needed for BSD/local learning)
            act: If False, stop after ConvTranspose2d + BN and return that value (no LIF)
            return_readout: Whether to also return readout features (for BSD loss)
            return_prelif: Whether to also return the pre-LIF values

        Returns:
            Neither flag: spike features
            return_readout only: (spike_features, readout_features)
            return_prelif only: (spike_features, pre_lif_features)
            Both: (spike_features, pre_lif_features, readout_features)
            readout_features is None when the readout branch is disabled or act=False;
            with act=False the ConvTranspose2d + BN output fills both the spike and
            pre-LIF slots.
        """
        if detach_grad:
            x = x.detach()

        # BN before the layer (mode 1), the transposed conv, BN after it (mode 2).
        # Any other mode value leaves x untouched.
        bn_mode = self.bw_bn
        if bn_mode == 1:
            x = self.backward_bn(x)
        if bn_mode in (0, 1, 2):
            x = self.backward_layer(x)
        if bn_mode == 2:
            x = self.backward_bn(x)

        if not act:
            return self._pack_outputs(x, x, None, return_readout, return_prelif)

        pre_lif = x.clone() if return_prelif else None
        spike_features = self.backward_lif(x)

        readout_features = None
        if return_readout and self.use_readout:
            readout_features = self.backward_readout_bn(self.backward_readout_conv(spike_features))

        return self._pack_outputs(spike_features, pre_lif, readout_features,
                                  return_readout, return_prelif)

